"""
Utility functions: math helpers, 2-D convolution and image display.

author: wxz
date: 2021-11-29
"""

import math
import numpy as np
from scipy import signal
from matplotlib import pylab as plt


# -----------------------------------------------
#                   math tools
# -----------------------------------------------
def distance(a, b):
    """
    Euclidean distance between two 2-D points.

    :param a: first point; only ``a[0]`` and ``a[1]`` are used
    :param b: second point; only ``b[0]`` and ``b[1]`` are used
    :return: ``sqrt((a[0]-b[0])**2 + (a[1]-b[1])**2)`` as a float

    ``math.hypot`` is used instead of squaring by hand so that very large
    coordinate differences do not overflow (for Python floats ``1e200 ** 2``
    raises ``OverflowError``).
    """
    dx = a[0] - b[0]
    dy = a[1] - b[1]
    return math.hypot(dx, dy)


def gauss_kernel(kernel_size=(3, 3), sigma=1):
    """
    Build a normalised 2-D Gaussian kernel.

    :param kernel_size: ``(n0, n1)`` number of taps. The returned array has
        shape ``(n1, n0)``: ``n0`` samples along the columns and ``n1`` along
        the rows, as laid out by ``np.meshgrid``.
    :param sigma: standard deviation of the Gaussian, in samples; must be > 0
    :return: float array whose entries sum to 1
    :raises ValueError: if ``sigma`` is not positive

    Samples are taken at unit spacing, centred on the kernel middle, i.e. at
    offsets ``i - (n - 1) / 2``. For odd sizes this is ``-r .. r``. For even
    sizes it gives half-integer offsets (e.g. -1.5, -0.5, 0.5, 1.5) rather
    than stretching ``-n/2 .. n/2`` over ``n`` points, which would not have
    unit spacing.
    """
    if sigma <= 0:
        raise ValueError(f"sigma must be positive, got {sigma!r}")

    n0, n1 = int(kernel_size[0]), int(kernel_size[1])
    # unit-spaced sample offsets, symmetric about the kernel centre
    c0 = np.arange(n0) - (n0 - 1) / 2.0
    c1 = np.arange(n1) - (n1 - 1) / 2.0
    x, y = np.meshgrid(c0, c1)

    out = np.exp(-(x ** 2 + y ** 2) / (2.0 * sigma ** 2))
    out = out / np.sum(out)

    return out


# -----------------------------------------------
#               convolution tools
# -----------------------------------------------
def conv2d(img, kernel):
    """
    2-D convolution whose output has the same size as ``img``.

    Borders are handled by mirror padding (``boundary='symm'`` in
    ``scipy.signal.convolve2d``); the result is returned as a numpy array.
    """
    result = signal.convolve2d(img, kernel, boundary='symm', mode='same')
    return np.asarray(result)


# -----------------------------------------------
#               image show tools
# -----------------------------------------------

def print_array_params(a):
    """
    Print the maximum, the minimum and the shape of array ``a``,
    one labelled value per line.
    """
    reports = (
        ('max:', np.max),
        ('min:', np.min),
        ('shape:', lambda arr: arr.shape),
    )
    for label, measure in reports:
        print(label, measure(a))


def _rggb_to_rgb(img):
    """
    Spread a 2-D RGGB Bayer mosaic into an (H, W, 3) R/G/B stack.

    Each sample is copied into the channel its mosaic position belongs to:
    R at (even row, even col), G at (even, odd) and (odd, even), and B at
    (odd, odd). The other two channels of that pixel stay zero.
    """
    r = np.zeros_like(img)
    g = np.zeros_like(img)
    b = np.zeros_like(img)
    r[::2, ::2] = img[::2, ::2]
    g[::2, 1::2] = img[::2, 1::2]
    g[1::2, ::2] = img[1::2, ::2]
    b[1::2, 1::2] = img[1::2, 1::2]
    return np.stack([r, g, b], axis=2)


def imshow(img, title='Image', cmap=None, max_val=2 ** 14 - 1):
    """
    Show an image in a matplotlib window and block until it is closed.

    :param img: image array
    :param title: figure title
    :param cmap: display mode
        - ``'gray'``: show ``img`` with matplotlib's ``'gray'`` colormap
        - ``'rggb'``: treat ``img`` as a 2-D RGGB Bayer mosaic, place each
          sample in its R/G/B channel and divide by ``max_val``
        - anything else (including ``'rgb'`` and ``None``): pass ``img`` to
          ``plt.imshow`` unchanged, without a colormap
    :param max_val: sample value shown at full brightness in ``'rggb'``
        mode. The default ``2 ** 14 - 1`` keeps the previous 14-bit
        assumption; pass ``1`` if ``img`` is already normalised to [0, 1].
    :raises ValueError: if ``cmap == 'rggb'`` and ``max_val`` is not positive
    """
    plt.title(title)
    if cmap == 'gray':
        plt.imshow(img, cmap='gray')
    elif cmap == 'rggb':
        if max_val <= 0:
            raise ValueError(f"max_val must be positive, got {max_val!r}")
        rgb = _rggb_to_rgb(img) / max_val
        # matplotlib expects float RGB in [0, 1]; clip explicitly instead of
        # relying on its implicit clipping (which emits a warning)
        plt.imshow(np.clip(rgb, 0.0, 1.0))
    else:
        plt.imshow(img)
    plt.show()


if __name__ == "__main__":
    path = "../datas/DSC_1339_768x512_rggb.raw"
    shape = [512, 768]  # rows, cols of the raw frame passed to reshape

    # Keep the raw integer sample values. imshow(..., cmap='rggb') already
    # divides by the 14-bit full-scale value (2 ** 14 - 1) itself. Normalising
    # here as well would scale the data twice and render it almost black.
    raw_img = np.fromfile(path, dtype='uint16', sep="").reshape(shape)

    # Split the RGGB mosaic into one sparse plane per colour channel.
    r = np.zeros_like(raw_img)
    g = np.zeros_like(raw_img)
    b = np.zeros_like(raw_img)

    r[::2, ::2] = raw_img[::2, ::2]
    g[::2, 1::2] = raw_img[::2, 1::2]
    g[1::2, ::2] = raw_img[1::2, ::2]
    b[1::2, 1::2] = raw_img[1::2, 1::2]

    imshow(r, title='test_r', cmap='rggb')
    imshow(g, title='test_g', cmap='rggb')
    imshow(b, title='test_b', cmap='rggb')
    imshow(raw_img, title='test_raw', cmap='rggb')
